In [1]:
import torch
from torch.autograd import Variable
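
A note for readers on PyTorch 0.4 or later: Variable has been merged into Tensor, so the wrapper used throughout this notebook is no longer needed. A minimal sketch of the modern equivalent, assuming PyTorch >= 0.4:

# Modern PyTorch (>= 0.4): tensors carry autograd state directly,
# so torch.autograd.Variable is unnecessary.
x = torch.randn(64, 1000)                       # requires_grad defaults to False
w = torch.randn(1000, 100, requires_grad=True)  # leaf tensor tracked by autograd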

In [2]:
dtype = torch.FloatTensor
N, D_in, H, D_out = 64, 1000, 100, 10  # batch size, input dim, hidden dim, output dim

In [3]:
x = Variable(torch.randn(N, D_in).type(dtype), requires_grad=False)

In [5]:
y = Variable(torch.randn(N, D_out).type(dtype), requires_grad=False)

In [6]:
x


Out[6]:
Variable containing:
-8.4927e-01  3.7209e-01  2.0565e-01  ...   2.6034e-01 -3.9997e-01 -2.0546e-01
-1.5018e+00 -8.2267e-01  2.9737e-01  ...   3.0785e-01 -6.0900e-01 -2.3617e+00
 1.4574e-01 -7.1956e-01 -9.4770e-01  ...  -1.0030e+00 -1.2861e-01 -2.0324e+00
                ...                   ⋱                   ...                
 1.4831e-01  2.0855e+00 -3.3356e-01  ...  -1.7427e+00  2.0441e+00 -1.8278e+00
 9.9977e-01 -9.7019e-01 -4.9967e-01  ...   1.1089e-01 -7.7166e-01 -8.2606e-02
-1.3062e+00  1.5458e+00 -1.6103e+00  ...  -1.7312e+00  1.1014e+00  5.4204e-01
[torch.FloatTensor of size 64x1000]

In [7]:
y


Out[7]:
Variable containing:
-0.7039  0.2815  0.4654 -1.8338 -1.0907 -0.9689 -1.2769  1.1185  0.1065  0.3366
 1.6305  0.9494  0.1230  0.2230  0.1182  0.4734  1.6509 -0.4793  0.0398  0.9409
-0.0052  0.3741 -0.0963  0.7438  0.7005 -0.0385 -1.0883 -0.0808 -0.1347 -0.1771
-0.1520  0.9047 -1.2712  0.9576  0.9655 -0.0732  0.5959  0.1848  0.0715  1.3470
-0.1557 -0.2296 -1.1731 -0.4178 -1.5174 -0.0211  0.9244  0.0867  0.3840 -1.2797
 0.4224 -1.0761  0.3312 -0.5842  0.0137 -0.6990  0.1891 -0.4143  2.2656 -0.9253
-1.2221 -0.4451  0.3090  0.6068 -0.4178 -0.7099 -0.1446  0.1633  0.1872 -1.2311
 0.5510 -0.2209  0.7475 -2.3571 -0.3969 -0.2426 -0.6853  1.1397 -1.9601  0.9470
-0.3749  0.0898  0.1697 -0.8002  0.6721 -1.2438  0.3401 -0.1466  0.9673  0.8664
-0.6045  0.7011 -1.3436 -0.3288  1.3448 -0.9595 -0.2083 -1.3035 -1.0293 -0.6756
-1.2699  0.6423 -1.3874  0.5259  0.2372  0.4391 -0.8772 -0.6586  0.5345  0.1963
 0.1210  1.6324 -0.1212  1.1554  0.9837 -0.8236 -0.4149 -0.3561  0.5146 -0.5968
-1.0086  0.4015  0.4426  0.2514  0.9107 -0.2473  1.7732 -0.9071  1.5172 -0.0594
-0.6084  0.1823 -0.0286 -0.9183  0.6871  0.0417 -1.5739 -0.8883  0.3590 -0.8505
 0.3084 -0.5938 -0.7139  1.5864  0.5191  0.6060 -2.1674 -1.7111  2.2235 -1.3032
-1.5008  0.3654 -0.5652  0.3552  0.4964 -0.7338 -0.7016 -1.8264  1.3548 -0.3621
 0.1020  1.4225 -0.0351  0.1048 -1.5217 -0.0522  0.0987  0.0321 -0.8761 -1.0961
 0.2908  1.2383  0.9669 -0.3732 -0.3580 -0.2620 -0.2158 -0.7708  0.5672 -0.5540
-0.5792  1.6062  1.1027 -0.6294  0.4294  0.9325 -0.5570 -0.7013 -0.7874 -0.2914
-0.6482 -0.9598  1.7271 -0.2496 -0.3051  0.2028 -0.5697  0.9403 -0.0792  1.2411
-0.3145  0.8643  0.4766  0.2900  0.1042  0.2783 -2.3802  0.7354  0.4622 -1.8160
 0.8366  0.5774  0.2863  0.5907 -0.5974 -0.7947 -0.7936 -0.4447 -0.3449  0.0895
 1.7629  0.3654 -0.8603 -0.6831  1.4193 -0.3151  1.4668 -0.2668 -0.7695 -0.4308
 0.8032 -0.1525 -1.1127 -1.3287 -0.5197 -0.1606 -0.1158 -0.9265  0.2989 -0.7426
-2.7512 -1.1786  0.3612 -0.9887  0.5813  0.1033 -0.8363  0.2216  0.3329  0.2135
-0.1291 -1.5119  0.6031 -1.3076 -1.1493 -1.3408  0.5302  0.6673  0.9512  1.0100
 0.6418 -1.1975  0.2449  1.7920  0.2127  0.3687 -1.9815 -0.4147 -0.5053  0.4923
-0.9598 -0.5152 -0.0972 -1.7389 -1.3450 -1.2923  1.0217  0.3225 -0.8006  0.2920
-0.5514  1.0085 -0.7528  0.6311 -0.0054 -0.7118  0.3161  0.7663 -0.4690  2.0582
-0.6338  1.6358  2.6860  1.3020  0.6157 -2.4639 -1.2917  0.4550 -1.3216  0.4453
-0.3183  0.1126 -0.2307 -0.5918 -0.7837  0.1659  0.8489  0.4408 -0.3900 -0.7129
 1.1971  0.8691  1.4649 -0.4326  1.6905  0.8588  0.8009 -0.7682  1.8111  1.5896
 0.7327  1.6515  0.6063  0.0619  0.0160  0.6064 -0.5834 -0.5514 -0.2470  0.4528
 0.9811  0.4572  1.2598 -0.2092 -0.7651  0.6956  0.0124  0.1255 -0.2539 -0.3896
 0.6420 -0.3476  0.1928  0.7596  0.4760  0.4884  2.2097  0.9823 -0.5446 -1.3484
-0.2843  0.2847  0.2663 -0.4300 -0.2792  1.5623 -1.1825 -0.8640 -0.6206  0.0206
-0.1187 -0.9158 -0.4900  1.0318  1.0442  1.2299  0.0395  0.9585 -1.7452  0.7026
 0.8861 -1.4200  0.1431 -1.0353  0.7656 -2.8105 -0.2938 -0.2858  0.9408 -1.8655
-2.0264 -1.5559 -0.8512 -1.8108 -1.8393 -0.6627 -0.5765  1.5017  0.1778 -2.8535
-0.8935  2.0577 -2.0361 -0.0905 -0.6137  0.6441 -0.1519  0.7535 -1.2513  0.6282
 0.2762 -0.2705 -0.7475 -0.4987  3.0299  0.1079  0.4253  0.9719  0.9451  0.3477
-0.4592  1.2081 -1.1140  1.2849  0.9143  0.1919  0.2821 -0.1014 -1.6709  1.4220
-0.7135  0.4135 -0.9898 -0.3129 -2.0325  1.2346  2.4821 -0.0923  1.8542 -1.0079
-0.2313  1.6956 -1.2057  0.4332  1.2227 -1.1784 -0.8522 -0.6338  1.4732 -0.4545
 0.8137  0.0266  1.0189 -0.7400 -1.0459 -1.7399 -0.5947  0.6087  0.2397  2.3152
 1.4909  0.7266  0.9611  0.9264  0.9975  1.0324 -0.0145 -0.2616 -0.8474  0.9146
-1.5231 -1.9571 -0.7830 -0.4465  0.6569  1.1156  0.6144  1.1333  0.2836 -0.3919
-1.1313 -0.6387 -0.7682  1.2239  0.7434  0.6707  0.1824  0.3920 -0.5081 -0.6057
-1.3370 -1.4948  0.9618 -0.8705  0.4390 -1.4316  0.0785  0.8036 -0.9607  0.3971
-1.0688 -0.2499 -1.9493  1.2108 -0.3599 -1.3142  0.0530  1.2412 -1.2947 -1.6781
-0.1321 -0.5509 -0.9050  1.0304  1.1261  0.1309  0.4431 -1.0128 -1.1900  0.9863
 0.6159  0.7466  1.3309  0.8241 -0.9954  0.3605  0.5362 -0.4171  0.1221  1.3147
 0.2097  2.6855 -0.2748  1.1670 -0.8673 -0.2316  1.9613  0.1863 -0.6195 -0.6429
 0.8443 -1.3725 -1.5043  1.9420  1.0535  0.4465  1.6409 -0.2299 -0.3553  0.2407
-0.4148 -1.6614 -0.4442 -0.9268  0.3548  0.7016  1.0779  0.3846 -0.9008 -0.6663
 0.2294 -1.0174 -1.3140  0.5548 -1.2077 -0.0527 -0.5051 -0.9171  1.7714  0.4038
-0.2566 -0.2060  0.3965 -0.8422  0.8616 -0.9902  0.4020  0.9346  0.6553  1.0078
-1.1111 -1.1017 -1.1860 -0.1094 -1.2254  0.7518  0.1536 -2.4607  1.0281  0.1950
 0.7814 -0.4105  1.1393  1.0611 -0.7637 -1.1841 -0.9820 -0.8033 -0.3506 -0.6181
-0.2276  0.5749 -1.6167 -1.2052 -1.2244  1.1740  0.3257  0.8713  1.5069  0.0978
-0.1131 -0.4294  0.6857 -0.3934  1.3345 -1.3020  1.3661 -1.1960  0.1325 -1.3745
-0.0550  0.2227 -0.9485 -0.7635  1.0899  0.2006  1.0376 -0.5738  0.4041  0.4747
-1.1798 -1.0384  1.1220  0.8526 -0.3145 -0.4695 -1.2921 -0.7177 -1.3901 -0.5076
-0.8907 -0.2934  0.1793 -0.1605 -0.6742 -1.8133  0.6108  0.3182 -0.3543 -1.8009
[torch.FloatTensor of size 64x10]

In [8]:
w1 = Variable(torch.randn(D_in, H).type(dtype), requires_grad=True)
w2 = Variable(torch.randn(H, D_out).type(dtype), requires_grad=True)

In [9]:
learning_rate = 1e-6

In [22]:
for t in range(200):
    y_pred = x.mm(w1).clamp(min=0).mm(w2)
    # clamp(min=0) acts as ReLU; clamp() constrains values to a given range
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.data[0])
    
    loss.backward()
    
    w1.data -= learning_rate * w1.grad.data
    w2.data -= learning_rate * w2.grad.data
    
    w1.grad.data.zero_()
    w2.grad.data.zero_()


0 26263210.0
1 21534024.0
2 22320400.0
3 25526674.0
4 28327110.0
5 27551494.0
6 22222566.0
7 14584345.0
8 8201597.5
9 4270791.5
10 2283170.25
11 1336163.875
12 881567.9375
13 645544.8125
14 508967.25
15 419834.71875
16 355584.125
17 305931.59375
18 265800.28125
19 232490.84375
20 204397.75
21 180446.40625
22 159900.75
23 142207.109375
24 126861.0859375
25 113488.578125
26 101788.84375
27 91523.421875
28 82485.96875
29 74501.8515625
30 67432.078125
31 61159.80078125
32 55574.27734375
33 50591.19921875
34 46134.19140625
35 42138.15625
36 38550.03125
37 35321.59765625
38 32408.068359375
39 29777.05078125
40 27396.869140625
41 25238.6484375
42 23277.82421875
43 21494.015625
44 19868.75
45 18385.412109375
46 17029.619140625
47 15789.5068359375
48 14652.9248046875
49 13609.9453125
50 12651.724609375
51 11770.2236328125
52 10957.97265625
53 10209.546875
54 9519.134765625
55 8881.5263671875
56 8291.82421875
57 7745.8212890625
58 7239.91455078125
59 6770.85986328125
60 6335.62890625
61 5931.59912109375
62 5555.9453125
63 5206.66845703125
64 4881.56689453125
65 4579.12841796875
66 4297.94482421875
67 4035.641357421875
68 3790.822021484375
69 3562.3193359375
70 3348.8330078125
71 3149.2373046875
72 2962.593994140625
73 2787.963623046875
74 2624.435546875
75 2471.22265625
76 2327.759765625
77 2193.265869140625
78 2067.25244140625
79 1949.160888671875
80 1838.314697265625
81 1734.2816162109375
82 1636.6318359375
83 1544.8223876953125
84 1458.454345703125
85 1377.2357177734375
86 1300.78857421875
87 1228.830078125
88 1161.0863037109375
89 1097.2935791015625
90 1037.1851806640625
91 980.55224609375
92 927.1463623046875
93 876.7980346679688
94 829.3092041015625
95 784.538818359375
96 742.2744750976562
97 702.4077758789062
98 664.769287109375
99 629.2286376953125
100 595.676025390625
101 563.9730834960938
102 534.0286254882812
103 505.7294616699219
104 478.99273681640625
105 453.72442626953125
106 429.8343811035156
107 407.2460021972656
108 385.88922119140625
109 365.6940002441406
110 346.5862121582031
111 328.50390625
112 311.3973083496094
113 295.2117004394531
114 279.8946533203125
115 265.39825439453125
116 251.67422485351562
117 238.68675231933594
118 226.3837890625
119 214.73158264160156
120 203.69444274902344
121 193.24545288085938
122 183.35833740234375
123 173.9917755126953
124 165.11451721191406
125 156.70401000976562
126 148.73126220703125
127 141.1772918701172
128 134.013427734375
129 127.22006225585938
130 120.7817611694336
131 114.67401123046875
132 108.88182830810547
133 103.39037322998047
134 98.18109130859375
135 93.23883056640625
136 88.55060577392578
137 84.10164642333984
138 79.88177490234375
139 75.87863159179688
140 72.07888793945312
141 68.47173309326172
142 65.05003356933594
143 61.801429748535156
144 58.71717071533203
145 55.79031753540039
146 53.01179122924805
147 50.374237060546875
148 47.86937713623047
149 45.49135971069336
150 43.23429870605469
151 41.08993148803711
152 39.05429458618164
153 37.12065887451172
154 35.284114837646484
155 33.53971862792969
156 31.8835506439209
157 30.309484481811523
158 28.815359115600586
159 27.395492553710938
160 26.046751022338867
161 24.764795303344727
162 23.547040939331055
163 22.39023208618164
164 21.291114807128906
165 20.246448516845703
166 19.253475189208984
167 18.310155868530273
168 17.413244247436523
169 16.561302185058594
170 15.750990867614746
171 14.981163024902344
172 14.249155044555664
173 13.553877830505371
174 12.892622947692871
175 12.26392650604248
176 11.666206359863281
177 11.098067283630371
178 10.557659149169922
179 10.043895721435547
180 9.555819511413574
181 9.091163635253906
182 8.649656295776367
183 8.229955673217773
184 7.8306884765625
185 7.450816631317139
186 7.08969259262085
187 6.746414661407471
188 6.419683456420898
189 6.108804702758789
190 5.81341028213501
191 5.532550811767578
192 5.265028953552246
193 5.010768413543701
194 4.768796443939209
195 4.538869380950928
196 4.319813251495361
197 4.1114301681518555
198 3.9134392738342285
199 3.724968671798706
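
The w.data updates above bypass autograd's graph tracking; in current PyTorch the same manual step is usually written inside a torch.no_grad() block. A minimal sketch, assuming PyTorch >= 0.4:

# Equivalent manual parameter update without Variable/.data:
with torch.no_grad():              # suspend graph recording for the update
    w1 -= learning_rate * w1.grad
    w2 -= learning_rate * w2.grad
    w1.grad.zero_()                # clear accumulated gradients
    w2.grad.zero_()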

In [18]:
loss


Out[18]:
Variable containing:
 2.6263e+07
[torch.FloatTensor of size 1]

In [19]:
loss[0]


Out[19]:
Variable containing:
 2.6263e+07
[torch.FloatTensor of size 1]

In [21]:
loss.data[0]


Out[21]:
26263210.0
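
loss.data[0] pulls the Python float out of the size-1 result. On PyTorch 0.4+ this indexing pattern is deprecated; the idiomatic accessor is .item(). A one-line sketch under that assumption:

value = loss.item()  # replaces loss.data[0] on PyTorch >= 0.4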

nn.Module


In [23]:
N, D_in, H, D_out = 64, 1000, 100, 10

x = Variable(torch.randn(N, D_in))
y = Variable(torch.randn(N, D_out), requires_grad=False)

In [25]:
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out))

loss_fn = torch.nn.MSELoss(size_average=False)
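
size_average=False makes MSELoss return the sum of squared errors rather than the mean, matching the .pow(2).sum() loss used earlier. On PyTorch 0.4+ the same choice is spelled with the reduction argument; a sketch under that assumption:

loss_fn = torch.nn.MSELoss(reduction='sum')  # modern equivalent of size_average=False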

In [27]:
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(500):
    y_pred = model(x)          # forward pass
    loss = loss_fn(y_pred, y)  # sum-of-squares loss
    print(t, loss.data[0])

    optimizer.zero_grad()  # gradients accumulate, so clear them first
    loss.backward()        # backpropagate through the model
    optimizer.step()       # apply the Adam update


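Why optimizer.zero_grad() must run every iteration: backward() adds into each parameter's .grad buffer instead of overwriting it. A minimal demonstration with the modern tensor API, assuming PyTorch >= 0.4:

w = torch.ones(3, requires_grad=True)
(w * 2).sum().backward()
print(w.grad)    # tensor([2., 2., 2.])
(w * 2).sum().backward()
print(w.grad)    # tensor([4., 4., 4.]); the gradients accumulated
w.grad.zero_()   # reset before the next backward pass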

Custom nn Module


In [31]:
class TwoLayerNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(TwoLayerNet, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        h_relu = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(h_relu)
        return y_pred

    
N, D_in, H, D_out = 64, 1000, 100, 10

x = Variable(torch.randn(N, D_in))
y = Variable(torch.randn(N, D_out), requires_grad=False)

model = TwoLayerNet(D_in, H, D_out)

criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.data[0])

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


0 704.9407348632812
1 652.8400268554688
2 608.1348266601562
3 569.0651245117188
4 534.7265625
5 504.0656433105469
6 476.438720703125
7 451.0474853515625
8 427.49835205078125
9 405.4925842285156
10 384.66424560546875
11 364.94561767578125
12 346.3260192871094
13 328.65045166015625
14 311.84051513671875
15 295.8565673828125
16 280.607666015625
17 266.1000061035156
18 252.32418823242188
19 239.26449584960938
20 226.8094024658203
21 214.90052795410156
22 203.5027618408203
23 192.63363647460938
24 182.24166870117188
25 172.3469696044922
26 162.90631103515625
27 153.92025756835938
28 145.37863159179688
29 137.26722717285156
30 129.55970764160156
31 122.24746704101562
32 115.29296875
33 108.70745086669922
34 102.47190856933594
35 96.5740966796875
36 91.00194549560547
37 85.74858856201172
38 80.77884674072266
39 76.08440399169922
40 71.66206359863281
41 67.48839569091797
42 63.560821533203125
43 59.86485290527344
44 56.3780517578125
45 53.093170166015625
46 50.00346374511719
47 47.09199142456055
48 44.34488296508789
49 41.76065444946289
50 39.32289505004883
51 37.02798843383789
52 34.86740493774414
53 32.83742141723633
54 30.9296932220459
55 29.130064010620117
56 27.435749053955078
57 25.839685440063477
58 24.337522506713867
59 22.92844581604004
60 21.60464859008789
61 20.361421585083008
62 19.188817977905273
63 18.087661743164062
64 17.050312042236328
65 16.072919845581055
66 15.15537166595459
67 14.2918701171875
68 13.482131004333496
69 12.721634864807129
70 12.00516414642334
71 11.331806182861328
72 10.699013710021973
73 10.102273941040039
74 9.540752410888672
75 9.011516571044922
76 8.513185501098633
77 8.044698715209961
78 7.604469299316406
79 7.189480304718018
80 6.7983784675598145
81 6.428240776062012
82 6.079920768737793
83 5.751997470855713
84 5.442923069000244
85 5.1512274742126465
86 4.875797748565674
87 4.615355968475342
88 4.370386123657227
89 4.139371395111084
90 3.921323776245117
91 3.7152864933013916
92 3.520789861679077
93 3.3373801708221436
94 3.1642515659332275
95 3.0007948875427246
96 2.8463776111602783
97 2.700451612472534
98 2.562469244003296
99 2.4319286346435547
100 2.3083183765411377
101 2.191359281539917
102 2.0809099674224854
103 1.976595163345337
104 1.8779417276382446
105 1.7845367193222046
106 1.6961863040924072
107 1.612591028213501
108 1.53348970413208
109 1.4586732387542725
110 1.3879107236862183
111 1.3209706544876099
112 1.2575525045394897
113 1.1974167823791504
114 1.1404786109924316
115 1.0864275693893433
116 1.0350898504257202
117 0.9863788485527039
118 0.9401412010192871
119 0.896259605884552
120 0.8545987010002136
121 0.8150573968887329
122 0.7774476408958435
123 0.7416736483573914
124 0.7077035307884216
125 0.6754247546195984
126 0.6447378396987915
127 0.615578830242157
128 0.5877849459648132
129 0.561370849609375
130 0.5362561941146851
131 0.5123559236526489
132 0.4895749092102051
133 0.4679161608219147
134 0.44729170203208923
135 0.42765408754348755
136 0.40891608595848083
137 0.39105284214019775
138 0.37406036257743835
139 0.35784912109375
140 0.3423929810523987
141 0.32764941453933716
142 0.31360793113708496
143 0.3002154231071472
144 0.28741639852523804
145 0.2752075493335724
146 0.26355940103530884
147 0.2524354159832001
148 0.2418164610862732
149 0.23168328404426575
150 0.22199808061122894
151 0.21274970471858978
152 0.2039027214050293
153 0.19544537365436554
154 0.18734878301620483
155 0.17961208522319794
156 0.172221377491951
157 0.1651483029127121
158 0.1583864390850067
159 0.1519290953874588
160 0.14573998749256134
161 0.13981814682483673
162 0.13416098058223724
163 0.12874101102352142
164 0.12355363368988037
165 0.11859158426523209
166 0.1138322725892067
167 0.10927708446979523
168 0.10491390526294708
169 0.10073729604482651
170 0.09673161059617996
171 0.09289216995239258
172 0.08921077847480774
173 0.08568665385246277
174 0.08230651170015335
175 0.07906626164913177
176 0.07596103101968765
177 0.07298443466424942
178 0.07012689858675003
179 0.0673920214176178
180 0.06476661562919617
181 0.06224823370575905
182 0.05983397364616394
183 0.05751638859510422
184 0.05529284477233887
185 0.05316101387143135
186 0.05111568793654442
187 0.0491514578461647
188 0.04726763069629669
189 0.045460451394319534
190 0.0437251515686512
191 0.042057573795318604
192 0.04045636206865311
193 0.038919392973184586
194 0.0374443493783474
195 0.03602704405784607
196 0.03466631472110748
197 0.03335927054286003
198 0.03210316225886345
199 0.030896373093128204
200 0.029737140983343124
201 0.02862221747636795
202 0.027550453320145607
203 0.02652229741215706
204 0.02553378976881504
205 0.024582702666521072
206 0.02366844192147255
207 0.02278982289135456
208 0.021945370361208916
209 0.021133193746209145
210 0.020353535190224648
211 0.0196025799959898
212 0.018880479037761688
213 0.018185772001743317
214 0.017517905682325363
215 0.016875600442290306
216 0.01625729165971279
217 0.015662629157304764
218 0.01509043388068676
219 0.014539840631186962
220 0.014010182581841946
221 0.013500729575753212
222 0.013010199181735516
223 0.012538616545498371
224 0.012084848247468472
225 0.011647390201687813
226 0.011226712726056576
227 0.010822208598256111
228 0.010432623326778412
229 0.010057469829916954
230 0.009696212597191334
231 0.009348739869892597
232 0.009014279581606388
233 0.008691804483532906
234 0.008381308987736702
235 0.008082120679318905
236 0.007794159930199385
237 0.007516804616898298
238 0.007249636575579643
239 0.006992410868406296
240 0.0067444597370922565
241 0.006505695637315512
242 0.006275609135627747
243 0.006053978111594915
244 0.005840450059622526
245 0.005634570959955454
246 0.00543615035712719
247 0.005244977306574583
248 0.0050608450546860695
249 0.004883205518126488
250 0.004712111782282591
251 0.004547155927866697
252 0.00438801059499383
253 0.004234801046550274
254 0.004087138921022415
255 0.00394468056038022
256 0.0038073535542935133
257 0.003674916224554181
258 0.003547340864315629
259 0.0034241785760968924
260 0.0033055213280022144
261 0.0031909740064293146
262 0.003080464666709304
263 0.00297395046800375
264 0.0028712681960314512
265 0.0027722008526325226
266 0.002676586853340268
267 0.0025844546034932137
268 0.00249566906131804
269 0.002409928245469928
270 0.0023272379767149687
271 0.0022473831195384264
272 0.0021703445818275213
273 0.0020960387773811817
274 0.002024322748184204
275 0.0019551957957446575
276 0.0018885227618739009
277 0.0018241048092022538
278 0.0017619712743908167
279 0.0017019875813275576
280 0.001644099480472505
281 0.0015882451552897692
282 0.0015343326376751065
283 0.0014823161764070392
284 0.0014320671325549483
285 0.001383612398058176
286 0.001336805522441864
287 0.001291608321480453
288 0.001247989945113659
289 0.0012058971915394068
290 0.0011652554385364056
291 0.0011259930906817317
292 0.00108811364043504
293 0.0010515094036236405
294 0.0010161867830902338
295 0.0009820872219279408
296 0.0009491592063568532
297 0.0009173551225103438
298 0.0008866394637152553
299 0.0008569861529394984
300 0.0008283648057840765
301 0.0008007331052795053
302 0.0007740394212305546
303 0.0007482757791876793
304 0.0007233863580040634
305 0.0006993439164943993
306 0.0006761051481589675
307 0.0006536853034049273
308 0.0006320223910734057
309 0.0006110715330578387
310 0.0005908329621888697
311 0.0005713051650673151
312 0.0005524358130060136
313 0.0005341876531019807
314 0.0005165667971596122
315 0.0004995397757738829
316 0.00048308740952052176
317 0.0004672037612181157
318 0.00045185041381046176
319 0.0004370079841464758
320 0.00042265403317287564
321 0.0004087786073796451
322 0.0003953759151045233
323 0.00038242238224484026
324 0.00036989664658904076
325 0.0003577994357328862
326 0.0003461156738922
327 0.0003348116879351437
328 0.0003238721110392362
329 0.00031331527861766517
330 0.0003031031519640237
331 0.00029322822229005396
332 0.0002836765197571367
333 0.0002744541852734983
334 0.0002655418065842241
335 0.00025692093186080456
336 0.00024859144468791783
337 0.00024052200024016201
338 0.0002327319234609604
339 0.00022519633057527244
340 0.00021791360632050782
341 0.00021087010100018233
342 0.0002040600375039503
343 0.0001974744809558615
344 0.00019109653658233583
345 0.00018494200776331127
346 0.00017898094665724784
347 0.0001732154778437689
348 0.00016764413157943636
349 0.00016225839499384165
350 0.00015704293036833405
351 0.0001519965153420344
352 0.0001471248542657122
353 0.00014241198368836194
354 0.00013784637849312276
355 0.00013343030877877027
356 0.00012916258128825575
357 0.00012502979370765388
358 0.000121036842756439
359 0.00011717302550096065
360 0.00011343348160153255
361 0.00010981606465065852
362 0.00010631549230311066
363 0.000102926031104289
364 9.965019853552803e-05
365 9.647956176195294e-05
366 9.341099939774722e-05
367 9.044147736858577e-05
368 8.757100295042619e-05
369 8.479260577587411e-05
370 8.210370287997648e-05
371 7.950419239932671e-05
372 7.698323315707967e-05
373 7.454615115420893e-05
374 7.21895630704239e-05
375 6.99073716532439e-05
376 6.770122854504734e-05
377 6.556375592481345e-05
378 6.349416798911989e-05
379 6.14887394476682e-05
380 5.9553833125391975e-05
381 5.7675530115375295e-05
382 5.5859622079879045e-05
383 5.4102332796901464e-05
384 5.2401243010535836e-05
385 5.075326771475375e-05
386 4.916162652079947e-05
387 4.761672607855871e-05
388 4.612157499650493e-05
389 4.467491089599207e-05
390 4.3275358621031046e-05
391 4.1920346120605245e-05
392 4.060834544361569e-05
393 3.933636617148295e-05
394 3.8108650187496096e-05
395 3.691560050356202e-05
396 3.576268136384897e-05
397 3.4647484426386654e-05
398 3.3567106584087014e-05
399 3.25204455293715e-05
400 3.150699922116473e-05
401 3.0524552130373195e-05
402 2.9574757718364708e-05
403 2.8654287234530784e-05
404 2.7764761398429982e-05
405 2.6901867386186495e-05
406 2.6066109057865106e-05
407 2.525615127524361e-05
408 2.4471668439218774e-05
409 2.371237860643305e-05
410 2.2978734705247916e-05
411 2.2266083760769106e-05
412 2.15761592698982e-05
413 2.090929410769604e-05
414 2.026238144026138e-05
415 1.9635857825051062e-05
416 1.9028007955057546e-05
417 1.8439952327753417e-05
418 1.7872249372885562e-05
419 1.7319864127784967e-05
420 1.6785656043794006e-05
421 1.6268346371361986e-05
422 1.576770046085585e-05
423 1.5282481399481185e-05
424 1.4811957953497767e-05
425 1.4356187421071809e-05
426 1.3914833289163653e-05
427 1.3487537216860801e-05
428 1.3073318768874742e-05
429 1.2671499462157954e-05
430 1.2283097930776421e-05
431 1.1906363397429232e-05
432 1.154173423856264e-05
433 1.1187486961716786e-05
434 1.0845144061022438e-05
435 1.0512251719774213e-05
436 1.0191248293267563e-05
437 9.87921066553099e-06
438 9.576732736604754e-06
439 9.284656698582694e-06
440 9.001995749713387e-06
441 8.726864507480059e-06
442 8.460524441034067e-06
443 8.201715900213458e-06
444 7.95229607319925e-06
445 7.709832061664201e-06
446 7.475702204828849e-06
447 7.247590474435128e-06
448 7.026540515653323e-06
449 6.81324263496208e-06
450 6.606455826840829e-06
451 6.404876330634579e-06
452 6.210942501638783e-06
453 6.022199158906005e-06
454 5.83955124966451e-06
455 5.662511739501497e-06
456 5.4904444368730765e-06
457 5.3236990424920805e-06
458 5.162190973351244e-06
459 5.006709670851706e-06
460 4.855074621445965e-06
461 4.7086505219340324e-06
462 4.566180450638058e-06
463 4.428090505825821e-06
464 4.2939959712384734e-06
465 4.1646590034361e-06
466 4.038670795125654e-06
467 3.9166538954305e-06
468 3.79907669412205e-06
469 3.684141802295926e-06
470 3.573150706870365e-06
471 3.465497684373986e-06
472 3.3608137073315447e-06
473 3.2596528853900963e-06
474 3.1616882552043535e-06
475 3.066701765419566e-06
476 2.974650669784751e-06
477 2.8847371140727773e-06
478 2.7981843686575303e-06
479 2.7147978016728302e-06
480 2.6329585125495214e-06
481 2.5534793621773133e-06
482 2.4773707991698757e-06
483 2.4028249754337594e-06
484 2.331420546397567e-06
485 2.2610458927374566e-06
486 2.1931941773800645e-06
487 2.1276539428072283e-06
488 2.0638226487790234e-06
489 2.0021216187160462e-06
490 1.942234121088404e-06
491 1.883944264591264e-06
492 1.828026711336861e-06
493 1.7733487993609742e-06
494 1.7204139339810354e-06
495 1.669164703343995e-06
496 1.6192320799746085e-06
497 1.5711593732703477e-06
498 1.5242372910506674e-06
499 1.4792458387091756e-06
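
Once training converges, the module's learned weights can be persisted through its state dict and reloaded into a fresh instance. A minimal sketch, with 'two_layer_net.pt' as a hypothetical filename:

torch.save(model.state_dict(), 'two_layer_net.pt')     # save learned weights
model2 = TwoLayerNet(D_in, H, D_out)                   # fresh, untrained instance
model2.load_state_dict(torch.load('two_layer_net.pt'))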
